from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
from jaxtyping import Bool, Float32, Int64
from torch import Tensor

from rfstudio.utils.tensor_dataclass import Float, Long, Size, TensorDataclass

from ._triangle_mesh import TriangleMesh


@lru_cache(maxsize=64)
def _get_cube_edges(device: torch.device) -> Int64[Tensor, "12*2"]:
    return torch.tensor([
        [0, 1],
        [1, 5],
        [4, 5],
        [0, 4],
        [2, 3],
        [3, 7],
        [6, 7],
        [2, 6],
        [2, 0],
        [3, 1],
        [7, 5],
        [6, 4],
    ], dtype=torch.long).flatten().to(device)


@lru_cache(maxsize=64)
def _get_check_table(device: torch.device) -> Int64[Tensor, "256 5"]:
    check_table = torch.zeros(256, 5, dtype=torch.long)
    nonempty_values = [
        [1, 1, 0, 0, 194],
        [1, -1, 0, 0, 193],
        [1, 0, 1, 0, 164],
        [1, 0, -1, 0, 161],
        [1, 0, 0, 1, 152],
        [1, 0, 0, 1, 145],
        [1, 0, 0, 1, 144],
        [1, 0, 0, -1, 137],
        [1, 0, 1, 0, 133],
        [1, 0, 1, 0, 132],
        [1, 1, 0, 0, 131],
        [1, 1, 0, 0, 130],
        [1, 0, 0, 1, 100],
        [1, 0, 0, 1, 98],
        [1, 0, 0, 1, 96],
        [1, 0, 1, 0, 88],
        [1, 0, -1, 0, 82],
        [1, 0, 1, 0, 74],
        [1, 0, 1, 0, 72],
        [1, 0, 0, -1, 70],
        [1, -1, 0, 0, 67],
        [1, -1, 0, 0, 65],
        [1, 1, 0, 0, 56],
        [1, -1, 0, 0, 52],
        [1, 1, 0, 0, 44],
        [1, 1, 0, 0, 40],
        [1, 0, 0, -1, 38],
        [1, 0, -1, 0, 37],
        [1, 0, -1, 0, 33],
        [1, -1, 0, 0, 28],
        [1, 0, -1, 0, 26],
        [1, 0, 0, -1, 25],
        [1, -1, 0, 0, 20],
        [1, 0, -1, 0, 18],
        [1, 0, 0, -1, 9],
        [1, 0, 0, -1, 6],
    ]
    for item in nonempty_values:
        check_table[255 - item[-1]] = torch.tensor(item, dtype=torch.long)
    return check_table.to(device)


@lru_cache(maxsize=64)
def _get_dmc_table(device: torch.device) -> Int64[Tensor, "256 4 7"]:
    """Return the 256-case dual-marching-cubes edge table, cached per device.

    Each of the 256 corner-sign configurations of a cube maps to up to 4
    groups of up to 7 cube-edge indices (0-11); unused slots are padded
    with -1. Each group presumably lists the edges that contribute to one
    dual vertex of that case (cf. ``_get_num_vd_table``) -- verify against
    the caller. The table is data-only; do not edit individual rows.
    """
    _ = -1  # padding sentinel for unused slots
    A = 10  # alias for edge index 10; single chars keep the rows aligned
    B = 11  # alias for edge index 11
    return torch.tensor([
        # cases 0-15
        [[_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 8, 9, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 7, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 5, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 5, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 7, 8, 9, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 7, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 7, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 5, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 16-31
        [[2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 8, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 8, 9, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 7, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [4, 7, 8, _, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 7, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 8, B, _, _, _], [4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 5, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 5, 8, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 7, 8, 9, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 7, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 7, 8, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 5, 7, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 32-47
        [[1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 9, A, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 8, 9, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 7, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 9, A, _, _, _], [4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 7, 9, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [4, 5, 9, _, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 5, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 5, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 7, 8, 9, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 7, 9, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 7, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 5, 7, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 48-63
        [[1, 3, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 8, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 9, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[8, 9, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [1, 3, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 7, A, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 9, A, B, _, _], [4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 9, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [1, 3, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 8, A, B, _, _], [4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 5, A, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 8, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 7, 8, 9, _, _, _], [1, 3, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 7, 9, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 7, 8, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 7, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 64-79
        [[6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 8, 9, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 6, 8, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 6, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [4, 6, 8, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 6, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [4, 5, 9, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 5, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 5, 8, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 6, 8, 9, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 6, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 6, 8, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 5, 6, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 80-95
        [[2, 3, 6, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 6, 7, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [2, 3, 6, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 6, 7, 8, 9, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 6, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [2, 3, 4, 6, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 6, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [2, 3, 6, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 6, 7, 8, _, _], [4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 5, _, _, _], [2, 3, 6, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 5, 6, 7, 8], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 5, 6, 8, 9, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 6, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 2, 3, 5, 6, 8], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 5, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 96-111
        [[1, 2, A, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [1, 2, A, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 9, A, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 8, 9, A, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 6, 8, B, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 6, B, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 9, A, _, _, _], [4, 6, 8, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 6, 9, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [1, 2, A, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [4, 5, 9, _, _, _, _], [1, 2, A, _, _, _, _], [6, 7, B, _, _, _, _]],
        [[0, 2, 4, 5, A, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 5, 8, A, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 6, 8, 9, B, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 6, 9, B, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 6, 8, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 5, 6, A, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 112-127
        [[1, 3, 6, 7, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 6, 7, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 6, 7, 9, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[6, 7, 8, 9, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 6, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 6, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 6, 8, 9, A], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 6, 9, A, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [1, 3, 6, 7, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 6, 7, 8, A, _], [4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 5, 6, 7, A], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 6, 7, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 5, 6, 8, 9, A], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 6, 9, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 128-143
        [[5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 8, 9, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 7, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [4, 7, 8, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 7, 9, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 6, 9, A, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [4, 6, 9, A, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 6, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 6, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[6, 7, 8, 9, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 6, 7, 9, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 6, 7, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 6, 7, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 144-159
        [[2, 3, B, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 8, B, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [2, 3, B, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 8, 9, B, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [2, 3, B, _, _, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 7, B, _, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [4, 7, 8, _, _, _, _], [2, 3, B, _, _, _, _], [5, 6, A, _, _, _, _]],
        [[1, 2, 4, 7, 9, B, _], [5, 6, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 6, 9, A, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 8, B, _, _, _], [4, 6, 9, A, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 6, A, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 6, 8, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[6, 7, 8, 9, A, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 6, 7, 9, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 6, 7, 8, A, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 6, 7, A, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 160-175
        [[1, 2, 5, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [1, 2, 5, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 6, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 5, 6, 8, 9, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [1, 2, 5, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 7, _, _, _], [1, 2, 5, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 6, 9, _, _], [4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 5, 6, 7, 9], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 6, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [1, 2, 4, 6, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 6, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 6, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 6, 7, 8, 9, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 2, 3, 6, 7, 9], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 6, 7, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 6, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 176-191
        [[1, 3, 5, 6, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 6, 8, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 6, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 6, 8, 9, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [1, 3, 5, 6, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 5, 6, 7, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 6, 9, B, _], [4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 6, 7, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 6, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 6, 8, 9, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 6, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 6, 8, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 6, 7, 8, 9, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 6, 7, 8, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[6, 7, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 192-207
        [[5, 7, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [5, 7, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [5, 7, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 8, 9, _, _, _], [5, 7, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 8, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 5, A, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [4, 5, 8, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 5, 9, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 9, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [4, 7, 9, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 7, A, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 7, 8, A, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[8, 9, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 9, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 8, A, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, A, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 208-223
        [[2, 3, 5, 7, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 7, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [2, 3, 5, 7, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 5, 7, 8, 9, A], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 5, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 5, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [2, 3, 4, 5, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 5, 9, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 7, 9, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 7, 8, 9, A], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 2, 3, 4, 7, A], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 8, 9, A, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 9, A, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 2, 3, 8, A, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, A, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 224-239
        [[1, 2, 5, 7, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [1, 2, 5, 7, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 5, 7, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 5, 7, 8, 9, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 5, 8, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 2, 3, 4, 5, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 5, 8, 9, B], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 4, 7, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [1, 2, 4, 7, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 4, 7, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, 4, 7, 8, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 2, 8, 9, B, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 2, 3, 9, B, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 2, 8, B, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[2, 3, B, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        # cases 240-255
        [[1, 3, 5, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 5, 7, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 5, 7, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[5, 7, 8, 9, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 5, 8, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 5, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 5, 8, 9, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 5, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 4, 7, 9, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 4, 7, 8, 9, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 4, 7, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[4, 7, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[1, 3, 8, 9, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 1, 9, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[0, 3, 8, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
        [[_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _], [_, _, _, _, _, _, _]],
    ], dtype=torch.long).to(device)


@lru_cache(maxsize=64)
def _get_num_vd_table(device: torch.device) -> Int64[Tensor, "256"]:
    return torch.tensor([
        [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1],
        [1, 1, 2, 1, 2, 1, 3, 1, 2, 2, 2, 1, 2, 1, 2, 1],
        [1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1],
        [1, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1],
        [1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1],
        [1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1],
        [2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, 2, 2, 2, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
        [1, 2, 2, 2, 2, 2, 3, 2, 1, 2, 1, 1, 1, 1, 1, 1],
        [2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1],
        [1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1],
        [1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1],
        [1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1],
        [1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
    ], dtype=torch.long).flatten().to(device)


@dataclass
class FlexiCubes(TensorDataclass):
    """Voxel-grid scalar field for differentiable FlexiCubes mesh extraction.

    Stores the grid vertices with their signed-distance values, the per-cube
    corner indices, and the optional per-cube weight parameters (alpha, beta,
    gamma) from "Flexible Isosurface Extraction for Gradient-Based Mesh
    Optimization" (SIGGRAPH 2023). A value `sdf < 0` marks a vertex as inside
    the surface (see the occupancy tests in the extraction methods below).
    """

    num_vertices: int = Size.Dynamic  # dynamic count of grid vertices
    num_cubes: int = Size.Dynamic     # dynamic count of grid cubes

    vertices: Tensor = Float[num_vertices, 3]    # grid vertex positions
    sdf_values: Tensor = Float[num_vertices, 1]  # signed distance per vertex (< 0 means inside)
    indices: Tensor = Long[num_cubes, 8]         # 8 corner-vertex indices per cube
    resolution: Tensor = Long[3]                 # grid resolution along each axis

    alpha: Optional[Tensor] = Float[num_cubes, 8]
    '''
    Weight parameters for the cube corners to adjust dual vertices positioning.
    Defaults to uniform value for all vertices.
    '''

    beta: Optional[Tensor] = Float[num_cubes, 12]
    '''
    Weight parameters for the cube edges to adjust dual vertices positioning.
    Defaults to uniform value for all edges.
    '''

    gamma: Optional[Tensor] = Float[num_cubes, 1]
    '''
    Weight parameters to control the splitting of quadrilaterals into triangles.
    Defaults to uniform value for all cubes.
    '''

    @classmethod
    def from_resolution(
        cls,
        *resolution: int,
        device: Optional[torch.device] = None,
        random_sdf: bool = True,
        scale: float = 1.0,
    ) -> FlexiCubes:
        """
        Generates a voxel grid based on the specified resolution.

        Args:
            resolution (int): The resolution of the voxel grid. Either a single
                integer, used for all three dimensions, or three integers giving
                the resolution for the x, y, and z dimensions respectively.
            device: Device on which all grid tensors are allocated.
            random_sdf: If True, each vertex receives a random SDF value drawn
                from [-0.1, 0.9); otherwise all SDF values are zero.
            scale: Half-extent of the grid; vertex coordinates span
                ``[-scale, scale]`` along every axis.

        Returns:
            FlexiCubes: A voxel grid centered at the origin whose ``indices``
                hold the 8 corner-vertex indices of every cube.
        """
        assert len(resolution) in [1, 3]
        if len(resolution) == 1:
            resolution = (resolution[0], resolution[0], resolution[0])
        voxel_grid_template = torch.ones(
            resolution[0] + 1,
            resolution[1] + 1,
            resolution[2] + 1,
            device=device,
        ) # [R + 1, R + 1, R + 1]

        # Offsets of the 8 corners of a unit cube, in the order the DMC tables expect.
        cube_corners = torch.tensor([
            [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
            [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1],
        ], dtype=torch.long, device=device) # [8, 3]

        res = torch.tensor(resolution, dtype=torch.long, device=device) # [3]
        # nonzero on an all-ones grid enumerates every lattice coordinate in row-major order.
        coords = torch.nonzero(voxel_grid_template).float() # [R+ * R+ * R+, 3]
        verts = coords / res # [R+ * R+ * R+, 3], normalized to [0, 1]
        # Cube ids run with the first dimension fastest; expand each cube to its 8 corners.
        cubes = torch.arange(resolution[0] * resolution[1] * resolution[2], device=device) # [R*R*R]
        cubes = torch.stack((
            cubes % resolution[0],
            (cubes // resolution[0]) % resolution[1],
            cubes // (resolution[1] * resolution[0]),
        ), dim=-1)[:, None, :] + cube_corners # [R*R*R, 8, 3]
        # NOTE(review): this linearization indexes vertices with the last component
        # fastest, while torch.nonzero enumerates dim 0 slowest; the two orderings
        # coincide only when resolution[0] == resolution[2] — confirm before using
        # anisotropic resolutions.
        cubes = (cubes[..., 2] * (1 + resolution[1]) + cubes[..., 1]) * (1 + resolution[0]) + cubes[..., 0] # [R * R * R, 8]

        sdfs = (
            (torch.rand_like(verts[..., 0:1]) - 0.1)
            if random_sdf
            else torch.zeros_like(verts[..., 0:1])
        )

        return FlexiCubes(
            vertices=(2 * verts - 1) * scale, # map [0, 1] -> [-scale, scale]
            indices=cubes,
            sdf_values=sdfs,
            resolution=res,
        )

    @torch.no_grad()
    def _get_case_id(
        self,
        occupancy: Bool[Tensor, "F 8"],
        surf_cubes: Bool[Tensor, "F"],
    ) -> Int64[Tensor, "N"]:
        """
        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
        supplementary material. It should be noted that this function assumes a regular grid.
        """
        res = self.resolution.tolist()
        # Encode the 8 corner occupancies of every surface cube into one byte (case id in [0, 255]).
        cube_corners_idx = torch.pow(2, torch.arange(8, device=occupancy.device)) # [8]
        case_ids = (occupancy[surf_cubes, :] * cube_corners_idx).sum(-1) # [N]

        # Look up which case ids are ambiguous; non-zero rows carry the offset of the
        # adjacent cube to check and the replacement case id.
        problem_config = _get_check_table(occupancy.device)[case_ids, :] # [N, 5]
        to_check = problem_config[..., 0] == 1                           # [N]
        problem_config = problem_config[to_check, :]                     # [P, 5]

        # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
        # This allows efficient checking on adjacent cubes.
        # NOTE(review): cube linear order is assumed to match torch.nonzero's row-major
        # enumeration over `res`; this holds for the cubic grids built by from_resolution —
        # confirm for anisotropic resolutions.
        problem_config_full = problem_config.new_zeros(res + [5]) # [R, R, R, 5]
        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # [R * R * R, 3] = [F, 3]
        vol_idx_problem = vol_idx[surf_cubes, :][to_check, :] # [P, 3]
        problem_config_full[
            vol_idx_problem[..., 0],
            vol_idx_problem[..., 1],
            vol_idx_problem[..., 2],
        ] = problem_config
        # Grid coordinates of the neighbour each problematic cube must be checked against.
        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] # [P, 3]

        within_range = (vol_idx_problem_adj >= 0).all(-1) & (vol_idx_problem_adj < self.resolution).all(-1) # [P]

        vol_idx_problem = vol_idx_problem[within_range]
        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
        problem_config = problem_config[within_range]
        problem_config_adj = problem_config_full[
            vol_idx_problem_adj[..., 0],
            vol_idx_problem_adj[..., 1],
            vol_idx_problem_adj[..., 2]
        ]
        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
        to_invert = (problem_config_adj[..., 0] == 1)
        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
        # Replace the inverted cubes' case ids with the alternative id stored in the check table.
        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
        return case_ids

    @torch.no_grad()
    def _identify_surf_edges(self, surf_cubes: Bool[Tensor, "F"]) -> Tuple[
        Int64[Tensor, "E 2"],
        Int64[Tensor, "F*12"],
        Int64[Tensor, "F*12"],
        Bool[Tensor, "F*12"],
    ]:
        """
        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
        and marks the cube edges with this index.

        Returns a tuple of (surface edges as vertex-index pairs, per-cube-edge map
        into the surface edges with -1 for non-intersecting edges, per-cube-edge
        share counts, and the per-cube-edge surface-intersection mask).
        """
        occupancy = self.sdf_values < 0 # [V, 1]
        # Gather the 12 edges of every surface cube as pairs of vertex indices.
        all_edges = self.indices[surf_cubes][:, _get_cube_edges(surf_cubes.device)].view(-1, 2) # [F * 12, 2]
        (
            unique_edges, # [U, 2]
            _idx_map,     # [F * 12]
            counts,       # [F * 12]
        ) = all_edges.unique(dim=0, return_inverse=True, return_counts=True)

        # An edge intersects the surface when exactly one endpoint is occupied.
        mask_edges = occupancy[unique_edges.flatten()].view(-1, 2).sum(-1) == 1 # [U]

        surf_edges_mask = mask_edges[_idx_map] # [F * 12]
        counts = counts[_idx_map]              # [F * 12]

        mapping = -self.indices.new_ones(unique_edges.shape[0])
        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=mapping.device)
        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
        idx_map = mapping[_idx_map]              # [F * 12]
        surf_edges = unique_edges[mask_edges, :] # [E, 2]
        return surf_edges, idx_map, counts, surf_edges_mask

    def _linear_interp(
        self,
        edges_weight: Float32[Tensor, "*bs 2 1"],
        edges_x: Float32[Tensor, "*bs 2 3"],
        *,
        sdf_eps: Optional[float]
    ) -> Float32[Tensor, "*bs 3"]:
        """Locate the zero-crossing on each edge of 'edges_x' by linearly
        interpolating between its endpoint weights 'edges_weight'."""
        sdf_start = edges_weight[..., 0, :]                # [..., 1]
        sdf_end = edges_weight[..., 1, :]                  # [..., 1]
        # Fraction of the way from the first endpoint toward the second at
        # which the weighted SDF changes sign.
        t = sdf_start / (sdf_start - sdf_end)              # [..., 1]
        if sdf_eps is not None:
            # Squeeze t away from the endpoints so crossings never coincide
            # with grid vertices.
            t = (1 - sdf_eps) * t + (sdf_eps / 2)          # [E, 1]
        return edges_x[..., 1, :] * t + edges_x[..., 0, :] * (1 - t) # [E, 3]

    def dual_marching_cubes(
        self,
        *,
        grad_func: Optional[Callable[[Tensor], Tensor]] = None,
        sdf_eps: Optional[float] = None,
        weight_scale: float = 0.99,
    ) -> Tuple[TriangleMesh, Float32[Tensor, "K"]]:
        r"""
        Main function for mesh extraction from scalar field using FlexiCubes. This function converts
        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
        to triangle or tetrahedral meshes using a differentiable operation as described in
        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
        mesh quality and geometric fidelity by adjusting the surface representation based on gradient
        optimization. The output surface is differentiable with respect to the input vertex positions,
        scalar field values, and weight parameters.

        If you intend to extract a surface mesh from a fixed Signed Distance Field without the
        optimization of parameters, it is suggested to provide the "grad_func" which should
        return the surface gradient at any given 3D position. When grad_func is provided, the process
        to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
        described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
        Please note, this approach is non-differentiable.

        For more details and example usage in optimization, refer to the
        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.

        Args:
            grad_func (callable, optional): A function to compute the surface gradient at specified
                3D positions (input: Nx3 positions). The function should return gradients as an Nx3
                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
                Note: the QEF dual-vertex placement is NOT implemented here — passing a
                non-None grad_func raises NotImplementedError.

        Returns:
            (TriangleMesh, torch.Tensor): Tuple containing:
                - The extracted triangular mesh.
                - Regularizer L_dev, computed per dual vertex.

        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
            https://research.nvidia.com/labs/toronto-ai/flexicubes/
        .. _Manifold Dual Contouring:
            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
        """

        # A cube straddles the surface iff some but not all of its corners lie inside (sdf < 0).
        cube_occupancy = (self.sdf_values < 0)[self.indices.flatten()].view(-1, 8) # [N, 8]
        cube_occ_sum = cube_occupancy.sum(-1)                                      # [N]
        surf_cubes = (cube_occ_sum > 0) & (cube_occ_sum < 8)                       # [N]
        N = surf_cubes.sum().item()
        assert N > 0

        # Squash the raw weight parameters into their valid ranges (Section 4.2 of the paper).
        beta = (
            torch.ones(N, 12, device=surf_cubes.device)
            if self.beta is None
            else (self.beta[surf_cubes].tanh() * weight_scale + 1)
        ) # [N, 12]
        alpha = (
            None
            if self.alpha is None
            else (self.alpha[surf_cubes].tanh() * weight_scale + 1)
        ) # [N, 8]
        gamma = (
            torch.ones(N, 1, device=surf_cubes.device)
            if self.gamma is None
            else (self.gamma[surf_cubes].sigmoid() * weight_scale + (1 - weight_scale) / 2)
        ) # [N, 1]
        case_ids = self._get_case_id(cube_occupancy, surf_cubes) # [N]
        (
            surf_edges,      # [E, 2]
            idx_map,         # [N * 12]
            edge_counts,     # [N * 12]
            surf_edges_mask  # [N * 12]
        ) = self._identify_surf_edges(surf_cubes)

        """
        Computes the location of dual vertices as described in Section 4.2
        """
        if alpha is not None:
            # Expand per-corner alphas to per-edge endpoint pairs (2 corners per edge).
            alpha = alpha.index_select(index=_get_cube_edges(surf_edges.device), dim=1).view(-1, 2) # [N * 12, 2]
        surf_edges_x = self.vertices.index_select(index=surf_edges.view(-1), dim=0).view(-1, 2, 3) # [E, 2, 3]
        surf_edges_s = self.sdf_values.index_select(index=surf_edges.view(-1), dim=0).view(-1, 2, 1) # [E, 2, 1]
        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x, sdf_eps=sdf_eps) # [E, 3]

        idx_map = idx_map.view(-1, 12) # [N, 12]
        num_vd = _get_num_vd_table(idx_map.device).index_select(index=case_ids, dim=0) # [N]
        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []

        total_num_vd = 0
        vd_idx_map = idx_map.new_zeros(case_ids.shape[0], 12) # [N, 12]
        if grad_func is not None:
            # QEF-based dual-vertex placement (Manifold Dual Contouring) is not supported here.
            raise NotImplementedError
        dmc_table = _get_dmc_table(idx_map.device) # [256, 4, 7]
        for num in range(5):
            cur_cubes: Tensor = (num_vd == num) # [N], consider cubes with the same numbers of vd emitted (for batching)
            curr_num_vd: int = cur_cubes.sum().item() * num # == G * num
            if curr_num_vd == 0:
                continue
            curr_edge_group = dmc_table[case_ids[cur_cubes], :num, :].view(-1, num * 7) # [G, num * 7]
            curr_edge_group_to_vd = torch.arange(curr_num_vd, device=self.device) + total_num_vd # [G * num]
            curr_edge_group_to_vd = curr_edge_group_to_vd[:, None].repeat(1, 7).view_as(curr_edge_group) # [G, num * 7]
            total_num_vd += curr_num_vd
            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[cur_cubes] # [G]
            curr_edge_group_to_cube = curr_edge_group_to_cube.unsqueeze(-1).expand_as(curr_edge_group) # [G, num * 7]

            # Entries of -1 in the DMC table are padding, not real edges.
            curr_mask = (curr_edge_group != -1) # [G, num * 7]
            edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) # [G']
            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd, curr_mask)) # [G']
            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) # [G']
            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) # [G * num, 1]
            vd_gamma.append(gamma[cur_cubes, :].repeat(1, num).reshape(-1)) # [G * num]

        edge_group = torch.cat(edge_group)                            # [K]
        edge_group_to_vd = torch.cat(edge_group_to_vd)                # [K]
        edge_group_to_cube = torch.cat(edge_group_to_cube)            # [K]
        vd_num_edges = torch.cat(vd_num_edges, dim=0)                 # [Q, 1]
        vd_gamma = torch.cat(vd_gamma, dim=0)                         # [Q]
        edge_group_idx = edge_group_to_cube * 12 + edge_group         # [K]

        vd = torch.zeros((total_num_vd, 3), device=self.device)       # [Q, 3]
        beta_sum = torch.zeros((total_num_vd, 1), device=self.device) # [Q, 1]

        idx_group = idx_map.flatten()[edge_group_idx]       # [K]
        x_group = surf_edges_x[idx_group, ...]              # [K, 2, 3]
        s_group = surf_edges_s[idx_group, ...]              # [K, 2, 1]
        zero_crossing_group = zero_crossing[idx_group, ...] # [K, 3]

        if alpha is not None:
            # Alpha re-weights the SDF values before interpolation, shifting the crossing point.
            alpha_group = alpha.index_select(dim=0, index=edge_group_idx).view(-1, 2, 1)
            ue_group = self._linear_interp(s_group * alpha_group, x_group, sdf_eps=sdf_eps) # [K, 3]
        else:
            ue_group = self._linear_interp(s_group, x_group, sdf_eps=sdf_eps) # [K, 3]

        # Each dual vertex is the beta-weighted mean of its edges' crossing points.
        beta_group = beta.flatten()[edge_group_idx].unsqueeze(-1)                               # [K, 1]
        beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)            # [Q, 1]
        vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum  # [Q, 3]
        L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) # [K]

        v_idx = torch.arange(vd.shape[0], device=self.device) # [Q]

        vd_idx_map = vd_idx_map.flatten().scatter(
            dim=0,
            index=edge_group_idx,                     # [K]
            src=v_idx[edge_group_to_vd],              # [K]
        )                                             # [N * 12]

        vertices, faces, _, _ = self._triangulate(
            surf_edges,                        # [E, 2]
            vd,                                # [Q, 3]
            vd_gamma,                          # [Q]
            edge_counts,                       # [N * 12]
            idx_map,                           # [N, 12]
            vd_idx_map,                        # [N * 12]
            surf_edges_mask,                   # [N * 12]
            grad_func,
        )
        if torch.is_anomaly_enabled():
            assert vertices.isfinite().all()
        return TriangleMesh(vertices=vertices, indices=faces), L_dev

    def compute_entropy(self) -> Float32[Tensor, "1"]:
        """Symmetric cross-entropy between the SDF signs of the two endpoints
        of every sign-crossing grid edge, pushing crossings toward agreement."""
        edge_pairs = self.indices[:, _get_cube_edges(self.device)].view(-1, 2)   # [F * 12, 2]
        edges = edge_pairs.unique(dim=0)                                         # [E, 2]
        inside = (self.sdf_values < 0).squeeze(-1)                               # [V]
        # An edge crosses the surface when its endpoints disagree on occupancy.
        crossing = inside[edges[:, 0]] != inside[edges[:, 1]]                    # [E]
        sdf_first = self.sdf_values[edges[crossing, 0]]                          # [E']
        sdf_second = self.sdf_values[edges[crossing, 1]]                         # [E']
        # Each endpoint's logit is trained against the other endpoint's sign.
        loss_first = F.binary_cross_entropy_with_logits(sdf_first, (sdf_second > 0).float())
        loss_second = F.binary_cross_entropy_with_logits(sdf_second, (sdf_first > 0).float())
        return loss_first + loss_second

    def _compute_reg_loss(
        self,
        vd: Float32[Tensor, "Q 3"],
        ue: Float32[Tensor, "K 3"],
        edge_group_to_vd: Int64[Tensor, "K"],
        vd_num_edges: Int64[Tensor, "Q 1"],
    ) -> Float32[Tensor, "K"]:
        """
        Regularizer L_dev as in Equation 8: per-edge absolute deviation of the
        crossing-to-dual-vertex distance from its dual vertex's mean distance.
        """
        # Distance from each edge's zero-crossing to its assigned dual vertex.
        deviation = (ue - vd[edge_group_to_vd]).norm(dim=-1)
        # Mean distance per dual vertex, accumulated via scatter-add.
        per_vd_sum = torch.zeros_like(vd[:, 0]).index_add_(0, edge_group_to_vd, deviation)
        per_vd_mean = per_vd_sum / vd_num_edges.squeeze(1).float()
        # Absolute deviation of each edge's distance from its vertex's mean.
        return (deviation - per_vd_mean[edge_group_to_vd]).abs()

    def _triangulate(
        self,
        surf_edges,
        vd,
        vd_gamma,
        edge_counts,
        idx_map,
        vd_idx_map,
        surf_edges_mask,
        grad_func,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
        triangles based on the gamma parameter, as described in Section 4.3.

        Args:
            surf_edges: [E, 2] vertex-index pairs of surface-intersecting grid edges.
            vd: [Q, 3] dual vertex positions.
            vd_gamma: [Q] per-dual-vertex gamma weights (recomputed from normals when grad_func is given).
            edge_counts: [N * 12] number of cubes sharing each cube edge.
            idx_map: [N, 12] surface-edge index for each cube edge (-1 for non-intersecting edges).
            vd_idx_map: [N * 12] dual-vertex index assigned to each cube edge.
            surf_edges_mask: [N * 12] mask of cube edges intersecting the surface.
            grad_func: optional surface-gradient callback used only to pick splitting weights.

        Returns:
            (vd, faces, s_edges, edge_indices): vertex positions (dual vertices plus the
            inserted quad-center vertices), [T, 3] triangle indices, per-quad endpoint
            SDF pairs, and the sorted surface-edge index of each quad corner.
        """
        with torch.no_grad():
            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
            group = idx_map.reshape(-1)[group_mask]
            vd_idx = vd_idx_map[group_mask]
            # Stable sort clusters the 4 dual vertices surrounding each surface edge.
            edge_indices, indices = torch.sort(group, stable=True)
            quad_vd_idx = vd_idx[indices].reshape(-1, 4)

            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
            s_edges = self.sdf_values[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
            flip_mask = s_edges[:, 0] > 0
            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
        if grad_func is not None:
            # When grad_func is given, weight the split by how well the normals of each
            # diagonal's endpoints agree (dot products of unit gradients).
            with torch.no_grad():
                vd_gamma = F.normalize(grad_func(vd), dim=-1)
                quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
        else:
            # Otherwise weight the split by the product of the learned gammas on each diagonal.
            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
            gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).unsqueeze(-1)
            gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).unsqueeze(-1)

        # Differentiable split (Section 4.3): insert a gamma-weighted center vertex
        # into each quad and fan four triangles around it, instead of committing to
        # a single diagonal.
        vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
        vd_02 = (vd_quad[:, 0:1] + vd_quad[:, 2:3]) / 2  # midpoint of diagonal 0-2, [Q', 1, 3]
        vd_13 = (vd_quad[:, 1:2] + vd_quad[:, 3:4]) / 2  # midpoint of diagonal 1-3, [Q', 1, 3]
        weight_sum = (gamma_02 + gamma_13) + 1e-8        # guard against a zero denominator
        vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
                     weight_sum.unsqueeze(-1)).squeeze(1)
        vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
        vd = torch.cat([vd, vd_center])
        faces = quad_vd_idx[:, [0, 1, 1, 2, 2, 3, 3, 0]].reshape(-1, 4, 2)
        faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
        return vd, faces, s_edges, edge_indices
